source("utils.R")
load("data/07.doc_surf_seas_pred.Rdata")
Analyse predictions of seasonal surface DOC
Model evaluation, interpretation & projections
Set-up and load data
Model evaluation
Rsquares
Black dots on the R² boxplots show the actual values.
# Flatten the nested `preds` list-column of `res`, keeping the fold and season
# each prediction belongs to
preds <- unnest(select(res, fold, season, preds), preds)

# One R² per season x fold, comparing `.pred` (estimate) with `log_doc_surf` (truth)
rsquares <- rsq(
  group_by(preds, season, fold),
  truth = log_doc_surf,
  estimate = .pred
)

# Summary of the R² table, printed separately for each season
map(split(rsquares, rsquares$season), summary)
$`1`
season fold .metric .estimator
Length:10 Length:10 Length:10 Length:10
Class :character Class :character Class :character Class :character
Mode :character Mode :character Mode :character Mode :character
.estimate
Min. :0.7783
1st Qu.:0.8121
Median :0.8285
Mean :0.8303
3rd Qu.:0.8520
Max. :0.8940
$`2`
season fold .metric .estimator
Length:10 Length:10 Length:10 Length:10
Class :character Class :character Class :character Class :character
Mode :character Mode :character Mode :character Mode :character
.estimate
Min. :0.4867
1st Qu.:0.6505
Median :0.7906
Mean :0.7459
3rd Qu.:0.8446
Max. :0.8958
$`3`
season fold .metric .estimator
Length:10 Length:10 Length:10 Length:10
Class :character Class :character Class :character Class :character
Mode :character Mode :character Mode :character Mode :character
.estimate
Min. :0.4037
1st Qu.:0.5679
Median :0.6516
Mean :0.6237
3rd Qu.:0.7150
Max. :0.7722
$`4`
season fold .metric .estimator
Length:10 Length:10 Length:10 Length:10
Class :character Class :character Class :character Class :character
Mode :character Mode :character Mode :character Mode :character
.estimate
Min. :0.5235
1st Qu.:0.6418
Median :0.6817
Mean :0.6897
3rd Qu.:0.7499
Max. :0.7980
# Plot R² values: one boxplot per season summarising the folds, overlaid with
# the individual fold values as black dots
ggplot(rsquares) +
  geom_boxplot(aes(x = season, y = .estimate, group = season, colour = season)) +
  # height = 0: spread the dots horizontally only. Without it geom_jitter()
  # also adds vertical noise, so the dots would not sit at the actual R² values.
  geom_jitter(
    aes(x = season, y = .estimate),
    size = 0.5, width = 0.1, height = 0
  ) +
  # R² axis pinned to [0, 1] without padding
  scale_y_continuous(limits = c(0, 1), expand = c(0, 0)) +
  labs(x = "Season", y = "R²", colour = "Season")
Predictions VS truth
Plot predictions vs. truth on the test part of each fold, for each season.
# Predicted vs observed values, one panel per season x fold combination;
# the red line is the 1:1 line (perfect prediction)
ggplot(preds) +
  geom_point(aes(x = log_doc_surf, y = .pred, colour = season)) +
  geom_abline(slope = 1, intercept = 0, colour = "red") +
  # same unit length on both axes so the 1:1 line is a true diagonal
  coord_fixed() +
  facet_wrap(vars(season, fold), ncol = 5)
Now let’s focus on a representative fold for each CV type.
# Within each season, keep the fold whose R² is closest to the seasonal median
# (the first one if several folds are equally close)
repres_fold <- rsquares %>%
  group_by(season) %>%
  mutate(diff = abs(.estimate - median(.estimate))) %>%
  filter(diff == min(diff)) %>%
  slice(1)

# Predictions belonging to the representative fold of each season
repres_preds <- left_join(
  select(repres_fold, season, fold),
  preds,
  by = join_by(season, fold)
)

# Predicted vs observed values for those folds, one panel per season
ggplot(repres_preds) +
  geom_point(aes(x = log_doc_surf, y = .pred, colour = season)) +
  geom_abline(slope = 1, intercept = 0, colour = "red") +
  coord_fixed() +
  labs(title = "Pred VS truth for a representative fold") +
  facet_wrap(vars(season), ncol = 2)
Model interpretation
Variable importance
Variable importance for each fold of each season.
# Unnest variable importance: flatten the nested `importance` tables of `res`,
# keeping the season and fold they belong to
full_vip <- res %>%
  select(season, fold, importance) %>%
  unnest(importance) %>%
  # order the factor levels by dropout loss so the y axis is sorted by importance
  mutate(variable = forcats::fct_reorder(variable, dropout_loss))

# Reference loss from the "_full_model_" rows, averaged within each
# season x fold so that every panel gets its own reference line.
# (A mean() written inside aes() is evaluated on the whole layer data and
# would draw the same global mean in every panel.)
full_model_loss_fold <- full_vip %>%
  filter(variable == "_full_model_") %>%
  group_by(season, fold) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop")

# Variable importance across folds: one panel per fold (rows) x season (columns)
full_vip %>%
  filter(variable != "_full_model_") %>%
  ggplot() +
  geom_vline(
    data = full_model_loss_fold,
    aes(xintercept = dropout_loss),
    colour = "grey", linewidth = 2
  ) +
  geom_boxplot(aes(x = dropout_loss, y = variable, colour = season)) +
  labs(x = "RMSE after permutations") +
  facet_grid(fold~season)
Now let’s take the mean across folds of each CV type.
# Reference loss from the "_full_model_" rows, averaged within each season so
# that every panel gets its own reference line.
# (A mean() written inside aes() is evaluated on the whole layer data and
# would draw the same global mean in every panel.)
full_model_loss_season <- full_vip %>%
  filter(variable == "_full_model_") %>%
  group_by(season) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop")

# Average the dropout loss within each season x fold x variable, then show the
# distribution of these per-fold means as one boxplot per variable and season
full_vip %>%
  filter(variable != "_full_model_") %>%
  group_by(season, fold, variable) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop") %>%
  ggplot() +
  geom_vline(
    data = full_model_loss_season,
    aes(xintercept = dropout_loss),
    colour = "grey", linewidth = 2
  ) +
  geom_boxplot(aes(x = dropout_loss, y = variable, colour = season)) +
  labs(x = "Mean RMSE after permutations across CV folds") +
  facet_wrap(~season, ncol = 4)
Partial dependence plots
Finally, let’s have a look at partial dependence plots.
blue line: prediction mean across cp profiles
grey ribbon: prediction sd across centered cp profiles
# Number of top-ranked variables per season for which a pdp is drawn
n_pdp <- 3

# Mean dropout loss of each variable within each season (all folds pooled);
# `variable` is converted back from factor to plain character
vip_by_season <- full_vip %>%
  filter(variable != "_full_model_") %>%
  mutate(variable = as.character(variable)) %>%
  group_by(season, variable) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop")

# Keep the `n_pdp` variables with the largest mean dropout loss in each season
# (the result stays grouped by season)
vars_pdp <- vip_by_season %>%
  arrange(desc(dropout_loss)) %>%
  group_by(season) %>%
  slice_head(n = n_pdp)

# Flatten the nested `cp_profiles` tables of `res`, keeping season and fold
cp_profiles <- unnest(select(res, season, fold, cp_profiles), cp_profiles)
## Build, for each season and each selected variable, one cp profile averaged
## across folds, propagating the spread of the individual profiles.
## The difficulty is that x values differ between folds; the solution is to
## interpolate yhat on a common set of x values across folds.
## Steps, for each season and each variable:
## 1- compute the mean and spread of cp profiles within each fold
## 2- interpolate mean and spread within each fold on a common set of x values
## 3- average mean and spread across folds, using 1/spread^2 as weights

# Names of folds, for later use
folds <- sort(unique(full_vip$fold))

# Apply on each (season, variable) row of `vars_pdp`.
# seq_len() instead of 1:nrow(): with zero rows, 1:0 would iterate over c(1, 0).
mean_pdp <- lapply(seq_len(nrow(vars_pdp)), function(r) {
  # Variable and season handled in this iteration
  var_name <- vars_pdp$variable[[r]]
  season_name <- vars_pdp$season[[r]]

  ## Step 1: mean and spread of the cp profiles within each fold
  d_pdp <- cp_profiles %>%
    filter(season == season_name & `_vname_` == var_name) %>%
    select(season, fold, `_yhat_`, `_vname_`, `_ids_`, all_of(var_name)) %>%
    arrange(`_ids_`, across(all_of(var_name))) %>%
    # center each individual profile (one per season, fold, variable and id)
    group_by(season, fold, `_vname_`, `_ids_`) %>%
    mutate(yhat_cent = `_yhat_` - mean(`_yhat_`)) %>%
    ungroup() %>%
    # for each fold and each value of the variable of interest:
    # mean of the raw profiles, sd of the centered profiles
    group_by(season, fold, across(all_of(var_name))) %>%
    summarise(
      yhat_loc = mean(`_yhat_`),
      yhat_spr = sd(yhat_cent),
      .groups = "drop"
    ) %>%
    # the third column carries the name of the variable: give it a generic name
    setNames(c("season", "fold", "x", "yhat_loc", "yhat_spr"))

  ## Step 2: interpolate mean and spread on a common x distribution
  # Common grid: percentiles of the x values pooled over folds
  new_x <- quantile(d_pdp$x, probs = seq(0, 1, 0.01), names = FALSE)

  # x differs between folds, so the interpolation is performed fold by fold
  int_pdp <- lapply(seq_along(folds), function(i) {
    fold_name <- folds[i]
    this_fold <- d_pdp %>% filter(fold == fold_name)

    # The fold's own x values must be passed as `this_fold$x`: a bare `x`
    # inside tibble() would resolve to the `x = new_x` column created just
    # before it, turning the interpolation into a no-op.
    tibble(
      x = new_x,
      yhat_loc = castr::interpolate(
        x = this_fold$x, y = this_fold$yhat_loc, xout = new_x
      ),
      yhat_spr = castr::interpolate(
        x = this_fold$x, y = this_fold$yhat_spr, xout = new_x
      )
    ) %>%
      mutate(
        season = season_name,
        fold = fold_name,
        var_name = var_name,
        .before = x
      )
  }) %>%
    bind_rows()

  ## Step 3: across folds, weighted mean of location and spread at each x,
  ## using 1/spread^2 as weights
  int_pdp %>%
    group_by(season, var_name, x) %>%
    summarise(
      yhat_loc = wtd.mean(yhat_loc, weights = 1 / yhat_spr^2),
      yhat_spr = wtd.mean(yhat_spr, weights = 1 / yhat_spr^2),
      .groups = "drop"
    ) %>%
    arrange(x)
}) %>%
  bind_rows()
# Attach the averaged profiles to `vars_pdp` so that rows follow its ordering
# (most important variables first), then turn `var_name` into a factor whose
# levels follow that order of appearance
mean_pdp <- left_join(
  rename(vars_pdp, var_name = variable),
  mean_pdp,
  by = join_by(season, var_name)
) %>%
  mutate(var_name = fct_inorder(var_name)) %>%
  select(-dropout_loss)

# One panel per variable, each with its own x axis:
# line = averaged profile, ribbon = averaged profile +/- averaged spread
ggplot(data = mean_pdp) +
  geom_path(mapping = aes(x = x, y = yhat_loc, colour = season)) +
  geom_ribbon(
    mapping = aes(
      x = x,
      ymin = yhat_loc - yhat_spr,
      ymax = yhat_loc + yhat_spr,
      fill = season
    ),
    alpha = 0.2
  ) +
  facet_wrap(vars(var_name), scales = "free_x")
New predictions
Collect predictions
# Flatten the nested `new_preds` tables of `res` (i.e. the projections)
# NOTE(review): `cv_type` is selected here while every other chunk keys on
# `season`; the `season` column used below is presumably provided by the
# nested `new_preds` tables - confirm.
new_preds <- unnest(select(res, fold, cv_type, new_preds), new_preds) %>%
  # back-transform: `pred_doc` is exp() of the predicted `pred_doc_log`
  mutate(pred_doc = exp(pred_doc_log), .after = pred_doc_log) %>%
  select(season, fold, lon, lat, contains("doc")) %>%
  # season as character, the same type as in `rsquares`, for the join below
  mutate(season = as.character(season))

# R² of each season x fold, renamed `rsq`
fold_rsq <- select(rsquares, season, fold, rsq = .estimate)

# Attach to every projection the R² of the fold and season it comes from
new_preds_strat <- left_join(new_preds, fold_rsq, by = join_by(season, fold))
## Average by pixel
# For each pixel and season, combine the projections of all folds, weighting
# each fold by its R²: weighted mean and weighted standard deviation
strat_proj <- new_preds_strat %>%
  group_by(lon, lat, season) %>%
  summarise(
    doc_avg = wtd.mean(pred_doc, weights = rsq, na.rm = TRUE),
    doc_sd = sqrt(wtd.var(pred_doc, weights = rsq, na.rm = TRUE)),
    .groups = "drop"
  )

# Colour bar limits (min, max) shared by all panels of the maps.
# na.rm = TRUE: a single NA/NaN pixel would otherwise turn both limits into NA.
doc_avg_lims <- range(strat_proj$doc_avg, na.rm = TRUE)
doc_sd_lims <- range(strat_proj$doc_sd, na.rm = TRUE)
Maps
Stratified CV
# Map of the weighted average DOC, one panel per season.
# Land polygons (`world`) are drawn in grey underneath the raster; the colour
# scale is log1p-transformed and uses limits shared across panels.
ggplot(strat_proj) +
  geom_polygon(data = world, aes(x = lon, y = lat, group = group), fill = "grey") +
  geom_raster(aes(x = lon, y = lat, fill = doc_avg)) +
  ggplot2::scale_fill_viridis_c(option = "F", limits = doc_avg_lims, trans = "log1p") +
  labs(title = "DOC avg from stratified CV") +
  # `expand` is a logical switch: FALSE (not 0) removes the padding around the map
  coord_quickmap(expand = FALSE) +
  facet_wrap(~season)

# Same map for the weighted standard deviation of DOC
ggplot(strat_proj) +
  geom_polygon(data = world, aes(x = lon, y = lat, group = group), fill = "grey") +
  geom_raster(aes(x = lon, y = lat, fill = doc_sd)) +
  ggplot2::scale_fill_viridis_c(option = "E", limits = doc_sd_lims, trans = "log1p") +
  labs(title = "DOC sd from stratified CV") +
  coord_quickmap(expand = FALSE) +
  facet_wrap(~season)
Seasonal cycle
# For each pixel, summarise the seasonal cycle of the averaged DOC:
# - seas_amp: largest minus smallest seasonal value
# - seas_var: variance of the seasonal values
seas_amp <- strat_proj %>%
  group_by(lon, lat) %>%
  summarise(
    seas_amp = max(doc_avg, na.rm = TRUE) - min(doc_avg, na.rm = TRUE),
    seas_var = var(doc_avg, na.rm = TRUE),
    .groups = "drop"
  )

# Map of the seasonal amplitude, with land polygons (`world`) in grey
ggplot(seas_amp) +
  geom_polygon(data = world, aes(x = lon, y = lat, group = group), fill = "grey") +
  geom_raster(aes(x = lon, y = lat, fill = seas_amp)) +
  scale_fill_viridis_c() +
  labs(title = "Seasonal amplitude") +
  # `expand` is a logical switch: FALSE (not 0) removes the padding around the map
  coord_quickmap(expand = FALSE)

# Map of the seasonal variance on a log1p colour scale; missing values are
# left uncoloured (na.value = NA)
ggplot(seas_amp) +
  geom_polygon(data = world, aes(x = lon, y = lat, group = group), fill = "grey") +
  geom_raster(aes(x = lon, y = lat, fill = seas_var)) +
  scale_fill_viridis_c(trans = "log1p", na.value = NA) +
  labs(title = "Seasonal variance") +
  coord_quickmap(expand = FALSE)